CUDA: fix FA occupancy, optimize tile kernel #15982
Conversation
I'll take a proper look at this, but will not be able to do so until the 17th.
Since you're already here, do you have an opinion on whether the HIP backend should be compiled with -ffast-math?
A quick grep suggests we use inf directly (see softmax), so blanket ffast-math is out. We could use some of the individual ffast-math flags, or apply ffast-math on a per-function or per-translation-unit basis, but I'm not sure it's worth it. In the past, on other code, LLVM fast-math has made things slower on HIP for some reason, and before ROCm 6.1 there were bugs I encountered where fast-math generated plain wrong code.
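For context, here is a minimal sketch of why blanket fast-math is risky here. The kernel and names are hypothetical (this is not the llama.cpp softmax): the masked-softmax pattern relies on IEEE semantics for `-INFINITY`, which `-ffast-math` (via `-ffinite-math-only`) allows the compiler to assume away.

```cuda
// Hypothetical illustration, not the llama.cpp kernel: masked softmax assumes IEEE inf semantics.
// -ffast-math implies -ffinite-math-only, so the compiler may assume no operand is ever -inf,
// which can break both the max reduction over masked values and the expf(-inf) == 0 guarantee.
#include <math.h>

__global__ void masked_softmax_row(float * x, const float * mask, const int ncols) {
    if (threadIdx.x != 0) return; // one thread per row, for clarity rather than speed
    float       * row  = x    + (size_t) blockIdx.x * ncols;
    const float * mrow = mask + (size_t) blockIdx.x * ncols;

    float max_val = -INFINITY;
    for (int i = 0; i < ncols; ++i) {
        row[i]  = mrow[i] == 0.0f ? -INFINITY : row[i]; // mask out positions with -inf
        max_val = fmaxf(max_val, row[i]);
    }

    float sum = 0.0f;
    for (int i = 0; i < ncols; ++i) {
        row[i] = expf(row[i] - max_val); // masked positions rely on expf(-inf) == 0.0f
        sum   += row[i];
    }
    for (int i = 0; i < ncols; ++i) {
        row[i] /= sum; // fully masked rows are not handled in this sketch
    }
}
```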
Changes look fine to me, and I can also confirm the performance delta on gfx1030. When making gfx908 use this code path I can't reproduce the same magnitude of performance improvement as @JohannesGaessler does on gfx906, but I find no regression. I noticed that this PR reduces the amount of spilled VGPRs (although some instances still spill, e.g. `_ZL15flash_attn_tileILi64ELi32ELb0EEvPKcS1_S1_S1_S1_PKiPfP15HIP_vector_typeIfLj2EEffffjfiiiiiiiiiiiiiliiliiiiil`), so it's possible that some of the extra improvement on gfx906 comes from reduced spills to scratch, whereas gfx908 can spill to AGPRs, which has a lower performance impact.
Side note:
some of the vector fattn kernels spill to high heaven:
```
Function Name: _ZL22flash_attn_vec_ext_f32ILi128ELi8EL9ggml_type8ELS0_8ELb1EEvPKcS2_S2_S2_S2_PKiPfP15HIP_vector_typeIfLj2EEffffjfiiiiiiiiiiiiiliiliiiiil
TotalSGPRs: 88
VGPRs: 63
AGPRs: 64
ScratchSize [bytes/lane]: 1100
Dynamic Stack: False
Occupancy [waves/SIMD]: 4
SGPRs Spill: 0
VGPRs Spill: 298
LDS Size [bytes/block]: 10240
```
That's 362 vector registers spilled in this kernel (298 VGPRs + 64 AGPRs), as the AGPRs are also spills in this case.
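As a generic aside on where such spills come from (an illustration with a hypothetical kernel, not a claim about these particular kernels): the compiler spills whenever live values exceed the per-thread register budget, and on CUDA/HIP that budget is commonly constrained via `__launch_bounds__` to trade register pressure against occupancy.

```cuda
// Generic illustration (hypothetical kernel, not from llama.cpp): __launch_bounds__ caps the
// number of registers the compiler may allocate per thread so that at least the requested
// number of blocks fits per SM/CU; if the kernel needs more live values than that cap, the
// excess spills to scratch (or, on gfx908/CDNA, can land in otherwise unused AGPRs, which is
// cheaper than spilling to memory).
__global__ void __launch_bounds__(256, 2) // 256 threads per block, at least 2 blocks resident
copy_kernel(float * dst, const float * src, const int n) {
    const int i = blockIdx.x*blockDim.x + threadIdx.x;
    if (i < n) {
        dst[i] = src[i];
    }
}
```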
One of my current efforts is to make the kernel parameters more configurable as a function of hardware. I intend to soon procure an RDNA4 GPU so that I can implement support for the AMD WMMA instructions in the mma FA kernel. In principle, if the mma kernel can be made to work it should perform best since you need to hold fewer registers than the tile kernel and unlike the WMMA kernel you don't have to go through shared memory. Can you give me a list of the AMD hardware that you have so that I can adjust my purchases for wider coverage?
Sure, I have gfx803 (Fiji / GCN3), gfx900 (Vega APU / GCN5), gfx906 (MI50 / GCN5.1), gfx908 (MI100 / CDNA1), gfx1030 (RX6800XT / RDNA2). I don't have any WMMA device at all, so any device with WMMA instructions would be very helpful. I know you don't intend to buy anything for actual use, but from a practical perspective the large-register-file RDNA3 GPUs (7900xtx, 7900xt, 7800xt) tend to be better for AI inference than RDNA4, simply on account of being bigger devices with more CUs, VRAM and bandwidth.
@IMbackK I have 7800xt, 7900xt and 7900xtx cards; how can I help you?
@broadbit-hu Not at the moment. For regression testing it is useful to have people around who regularly run llama.cpp on a given arch, but we were talking about doing feature development. When doing feature development, the dev in question really needs to have the device with the instructions to be implemented on hand in one of their machines.
This reverts commit c959b67.
This PR fixes a bug in the CUDA FlashAttention occupancy calculation: in rare cases too few kernels were launched in parallel, costing a few % of performance.
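For readers unfamiliar with this kind of calculation, here is a hedged sketch of the general pattern such occupancy logic follows (illustrative only, with made-up names, not the actual fix): the kernel's register and shared-memory footprint determines how many blocks fit per SM, and that bound decides how many blocks are worth launching in parallel.

```cuda
// Illustrative sketch with hypothetical names, not the llama.cpp implementation:
// querying how many blocks of a kernel can be resident per SM and scaling by the
// SM count gives an upper bound on useful parallel blocks. If this number comes
// out too low, part of the GPU sits idle, which is the kind of bug described above.
#include <cuda_runtime.h>

template <typename KernelFn>
static int max_parallel_blocks(KernelFn kernel, const int block_size, const size_t smem_bytes) {
    int device = 0;
    cudaGetDevice(&device);

    int nsm = 0;
    cudaDeviceGetAttribute(&nsm, cudaDevAttrMultiProcessorCount, device);

    int blocks_per_sm = 0;
    cudaOccupancyMaxActiveBlocksPerMultiprocessor(&blocks_per_sm, kernel, block_size, smem_bytes);

    return blocks_per_sm * nsm; // how many blocks can actually run concurrently
}
```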
This PR also delivers what I expect to be the last round of performance optimizations for the tile FA kernel: I revised the memory layout to consistently copy data in 8/16 byte chunks, and delayed writing the KQ accumulators to shared memory until after they have been compressed to FP16. I also looked up the amount of shared memory on each AMD GPU and fit the tile sizes accordingly. One thing that could still be done is the same GQA optimization as for the mma kernel, but because the GPUs using the tile kernel are comparatively slow, reducing the mask I/O has little impact; it could still improve performance for small batch sizes > 1, though.
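To illustrate the 8/16-byte copy idea (a generic sketch with hypothetical names, not the actual tile kernel code): staging a tile into shared memory in 16-byte chunks lets each thread issue a single 128-bit load/store instead of several narrow ones.

```cuda
// Hypothetical sketch, not the llama.cpp tile kernel: cooperative 16-byte copies into
// shared memory. One int4 is 16 bytes == 8 half values; pointers are assumed to be
// 16-byte aligned and the element count a multiple of 8.
#include <cuda_fp16.h>

template <int ne> // number of half elements in the tile
__device__ void stage_tile_16B(half * __restrict__ dst_shared, const half * __restrict__ src_global) {
    const int n_chunks = ne / 8;
    for (int i = threadIdx.x; i < n_chunks; i += blockDim.x) {
        reinterpret_cast<int4 *>(dst_shared)[i] =
            reinterpret_cast<const int4 *>(src_global)[i];
    }
}
```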
Performance changes